# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Functional operations.

See the @{$python/functional_ops} guide.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
# pylint: disable=unused-import
from tensorflow.python.ops.gen_functional_ops import remote_call
# pylint: enable=unused-import
from tensorflow.python.ops.gen_functional_ops import symbolic_gradient
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


# TODO(yuanbyu, mrry): Handle stride to support sliding windows.
@tf_export("foldl")
def foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
          swap_memory=False, name=None):
  """foldl on the list of tensors unpacked from `elems` on dimension 0.

  This foldl operator repeatedly applies the callable `fn` to a sequence
  of elements from first to last. The elements are made of the tensors
  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of fn, and the second argument is the current element.
  If `initializer` is None, `elems` must contain at least one element, and its
  first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `fn(initializer, values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`.  If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension.  The second argument of `fn`
  has the same structure as `elems`.  That is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, then `fn` is called as
  `fn(acc, (t1_i, [t2_i, t3_i, [t4_i, t5_i]]))`, where each `tk_i` is the
  `i`-th slice of `tk`.

  Args:
    fn: The callable to be performed.  It is called as `fn(acc, elem)`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which
      will be unpacked along their first dimension.  The nested sequence
      of the resulting slices will be the second argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      as the initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run
      in parallel.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors, resulting from applying
    `fn` consecutively to the list of tensors unpacked from `elems`, from first
    to last.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = tf.constant([1, 2, 3, 4, 5, 6])
    sum = foldl(lambda a, x: a + x, elems)
    # sum == 21
    ```
  """
  if not callable(fn):
    raise TypeError("fn must be callable.")

  def create_ta(elem):
    # `n` is looked up in the enclosing scope at call time; it is assigned
    # below before this helper is first invoked.
    return tensor_array_ops.TensorArray(
        dtype=elem.dtype, size=n, dynamic_size=False,
        infer_shape=True).unstack(elem)

  in_graph_mode = not context.executing_eagerly()
  with ops.name_scope(name, "foldl", [elems]):
    varscope = None
    varscope_caching_device_was_none = False
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally
      # and not issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    try:
      # Convert elems to tensor array. n may be known statically.
      elems_flat = [
          ops.convert_to_tensor(elem, name="elem")
          for elem in nest.flatten(elems)
      ]
      n = elems_flat[0].shape[0].value or array_ops.shape(elems_flat[0])[0]

      elems_ta = nest.map_structure(create_ta, elems)

      if initializer is None:
        # Seed the accumulator with the first element and start at index 1.
        a = nest.map_structure(lambda elem: elem.read(0), elems_ta)
        i = constant_op.constant(1)
      else:
        a = initializer
        i = constant_op.constant(0)

      def compute(i, a):
        elem_i = nest.map_structure(lambda elem: elem.read(i), elems_ta)
        a = fn(a, elem_i)
        return [i + 1, a]

      _, r_a = control_flow_ops.while_loop(
          lambda i, a: i < n, compute, [i, a],
          parallel_iterations=parallel_iterations,
          back_prop=back_prop,
          swap_memory=swap_memory)

      return r_a
    finally:
      # Undo the temporary caching device even when `fn` or loop construction
      # raises, so the caller's variable scope is never left modified.
      # TODO(akshayka): Remove the in_graph_mode check once caching devices
      # are supported in Eager
      if in_graph_mode and varscope_caching_device_was_none:
        varscope.set_caching_device(None)


@tf_export("foldr")
def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
          swap_memory=False, name=None):
  """foldr on the list of tensors unpacked from `elems` on dimension 0.

  This foldr operator repeatedly applies the callable `fn` to a sequence
  of elements from last to first. The elements are made of the tensors
  unpacked from `elems`. The callable fn takes two tensors as arguments.
  The first argument is the accumulated value computed from the preceding
  invocation of fn, and the second argument is the current element. If
  `initializer` is None, `elems` must contain at least one element, and its
  last element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `fn(initializer, values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`.  If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension.  The second argument of `fn`
  has the same structure as `elems`.  That is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, then `fn` is called as
  `fn(acc, (t1_i, [t2_i, t3_i, [t4_i, t5_i]]))`, where each `tk_i` is the
  `i`-th slice of `tk`.

  Args:
    fn: The callable to be performed.  It is called as `fn(acc, elem)`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which
      will be unpacked along their first dimension.  The nested sequence
      of the resulting slices will be the second argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      as the initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run
      in parallel.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors, resulting from applying
    `fn` consecutively to the list of tensors unpacked from `elems`, from last
    to first.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = tf.constant([1, 2, 3, 4, 5, 6])
    sum = foldr(lambda a, x: a + x, elems)
    # sum == 21
    ```
  """
  if not callable(fn):
    raise TypeError("fn must be callable.")

  def create_ta(elem):
    # `n` is looked up in the enclosing scope at call time; it is assigned
    # below before this helper is first invoked.
    return tensor_array_ops.TensorArray(
        dtype=elem.dtype, size=n, dynamic_size=False,
        infer_shape=True).unstack(elem)

  in_graph_mode = not context.executing_eagerly()
  with ops.name_scope(name, "foldr", [elems]):
    varscope = None
    varscope_caching_device_was_none = False
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally and not
      # issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    try:
      # Convert elems to tensor array. n may be known statically.
      elems_flat = [
          ops.convert_to_tensor(elem, name="elem")
          for elem in nest.flatten(elems)
      ]
      n = elems_flat[0].shape[0].value or array_ops.shape(elems_flat[0])[0]

      elems_ta = nest.map_structure(create_ta, elems)

      if initializer is None:
        # Seed the accumulator with the last element; the loop then resumes
        # from the element just before it.
        i = n - 1
        a = nest.map_structure(lambda elem: elem.read(i), elems_ta)
      else:
        i = n
        a = initializer

      def compute(i, a):
        # The counter points one past the element to consume, so step back
        # before reading.
        i -= 1
        elem = nest.map_structure(lambda elem: elem.read(i), elems_ta)
        a_out = fn(a, elem)
        return [i, a_out]

      _, r_a = control_flow_ops.while_loop(
          lambda i, a: i > 0,
          compute, [i, a],
          parallel_iterations=parallel_iterations,
          back_prop=back_prop,
          swap_memory=swap_memory)

      return r_a
    finally:
      # Undo the temporary caching device even when `fn` or loop construction
      # raises, so the caller's variable scope is never left modified.
      # TODO(akshayka): Remove the in_graph_mode check once caching devices
      # are supported in Eager
      if in_graph_mode and varscope_caching_device_was_none:
        varscope.set_caching_device(None)


@tf_export("map_fn")
def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
           swap_memory=False, infer_shape=True, name=None):
  """map on the list of tensors unpacked from `elems` on dimension 0.

  The simplest version of `map_fn` repeatedly applies the callable `fn` to a
  sequence of elements from first to last. The elements are made of the
  tensors unpacked from `elems`. `dtype` is the data type of the return
  value of `fn`. Users must provide `dtype` if it is different from
  the data type of `elems`.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `[values.shape[0]] + fn(values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`.  If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension.  The signature of `fn` may
  match the structure of `elems`.  That is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
  `fn = lambda (t1, [t2, t3, [t4, t5]]):`.

  Furthermore, `fn` may emit a different structure than its input.  For example,
  `fn` may look like: `fn = lambda t1: (t1 + 1, t1 - 1)`.  In this case,
  the `dtype` parameter is not optional: `dtype` must be a type or (possibly
  nested) tuple of types matching the output of `fn`.

  To apply a functional operation to the nonzero elements of a SparseTensor
  one of the following methods is recommended. First, if the function is
  expressible as TensorFlow ops, use

  ```python
  result = SparseTensor(input.indices, fn(input.values), input.dense_shape)
  ```

  If, however, the function is not expressible as a TensorFlow op, then use

  ```python
  result = SparseTensor(
    input.indices, map_fn(fn, input.values), input.dense_shape)
  ```

  instead.

  Args:
    fn: The callable to be performed.  It accepts one argument, which will
      have the same (possibly nested) structure as `elems`.  Its output
      must have the same structure as `dtype` if one is provided, otherwise
      it must have the same structure as `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which
      will be unpacked along their first dimension.  The nested sequence
      of the resulting slices will be applied to `fn`.
    dtype: (optional) The output type(s) of `fn`.  If `fn` returns a structure
      of Tensors differing from the structure of `elems`, then `dtype` is not
      optional and must have the same structure as the output of `fn`.
    parallel_iterations: (optional) The number of iterations allowed to run
      in parallel.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    infer_shape: (optional) False disables tests for consistent output shapes.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors.  Each tensor packs the
    results of applying `fn` to tensors unpacked from `elems` along the first
    dimension, from first to last.

  Raises:
    TypeError: if `fn` is not callable or the structure of the output of
      `fn` and `dtype` do not match, or if elems is a SparseTensor.
    ValueError: if the lengths of the output of `fn` and `dtype` do not match,
      or if the first tensor in `elems` is statically known to be a scalar.

  Examples:
    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    squares = map_fn(lambda x: x * x, elems)
    # squares == [1, 4, 9, 16, 25, 36]
    ```

    ```python
    elems = (np.array([1, 2, 3]), np.array([-1, 1, -1]))
    alternate = map_fn(lambda x: x[0] * x[1], elems, dtype=tf.int64)
    # alternate == [-1, 2, -3]
    ```

    ```python
    elems = np.array([1, 2, 3])
    alternates = map_fn(lambda x: (x, -x), elems, dtype=(tf.int64, tf.int64))
    # alternates[0] == [1, 2, 3]
    # alternates[1] == [-1, -2, -3]
    ```
  """
  if not callable(fn):
    raise TypeError("fn must be callable.")

  if isinstance(elems, sparse_tensor.SparseTensor):
    raise TypeError(
        "To perform a map on the values of a sparse tensor use either "
        " SparseTensor(input.indices, fn(input.values), input.dense_shape) or "
        " SparseTensor(input.indices, map_fn(fn, input.values), "
        "input.dense_shape)")

  # Helpers that convert between the (possibly nested) user-facing structure
  # and the flat lists used internally.
  input_is_sequence = nest.is_sequence(elems)
  input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x]
  def input_pack(x):
    return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0]

  if dtype is None:
    # Without an explicit dtype the output is assumed to mirror the input.
    output_is_sequence = input_is_sequence
    output_flatten = input_flatten
    output_pack = input_pack
  else:
    output_is_sequence = nest.is_sequence(dtype)
    output_flatten = lambda x: nest.flatten(x) if output_is_sequence else [x]
    def output_pack(x):
      return (nest.pack_sequence_as(dtype, x)
              if output_is_sequence else x[0])

  elems_flat = input_flatten(elems)

  in_graph_mode = not context.executing_eagerly()
  with ops.name_scope(name, "map", elems_flat):
    varscope = None
    varscope_caching_device_was_none = False
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally
      # and not issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    try:
      elems_flat = [
          ops.convert_to_tensor(elem, name="elem") for elem in elems_flat]

      dtype = dtype or input_pack([elem.dtype for elem in elems_flat])
      dtype_flat = output_flatten(dtype)

      # Convert elems to tensor array. n may be known statically.
      static_shape = elems_flat[0].shape
      if static_shape.ndims is not None and static_shape.ndims < 1:
        if len(elems_flat) == 1:
          raise ValueError(
              "elems must be a 1+ dimensional Tensor, not a scalar")
        else:
          raise ValueError(
              "elements in elems must be 1+ dimensional Tensors, not scalars"
          )
      n = static_shape[0].value or array_ops.shape(elems_flat[0])[0]

      # TensorArrays are always flat
      elems_ta = [
          tensor_array_ops.TensorArray(dtype=elem.dtype, size=n,
                                       dynamic_size=False,
                                       infer_shape=True)
          for elem in elems_flat]
      # Unpack elements
      elems_ta = [
          elem_ta.unstack(elem) for elem_ta, elem in zip(elems_ta, elems_flat)]

      i = constant_op.constant(0)

      accs_ta = [
          tensor_array_ops.TensorArray(dtype=dt, size=n,
                                       dynamic_size=False,
                                       infer_shape=infer_shape)
          for dt in dtype_flat]

      def compute(i, tas):
        """The loop body of map_fn.

        Args:
          i: the loop counter
          tas: the flat TensorArray accumulator list

        Returns:
          (i + 1, tas): the updated counter + updated TensorArrays

        Raises:
          TypeError: if dtype and packed_fn_values structure do not match
          ValueError: if dtype and packed_fn_values lengths do not match
        """
        packed_values = input_pack([elem_ta.read(i) for elem_ta in elems_ta])
        packed_fn_values = fn(packed_values)
        nest.assert_same_structure(dtype or elems, packed_fn_values)
        flat_fn_values = output_flatten(packed_fn_values)
        tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_fn_values)]
        return (i + 1, tas)

      _, r_a = control_flow_ops.while_loop(
          lambda i, _: i < n, compute, (i, accs_ta),
          parallel_iterations=parallel_iterations,
          back_prop=back_prop,
          swap_memory=swap_memory)
      results_flat = [r.stack() for r in r_a]

      # Combine the static leading dimension of every input.  merge_with
      # returns the merged Dimension rather than mutating in place, so the
      # result must be kept for the most specific value to reach the outputs.
      n_static = elems_flat[0].get_shape().with_rank_at_least(1)[0]
      for elem in elems_flat[1:]:
        n_static = n_static.merge_with(
            elem.get_shape().with_rank_at_least(1)[0])
      for r in results_flat:
        r.set_shape(tensor_shape.TensorShape(n_static).concatenate(
            r.get_shape()[1:]))

      return output_pack(results_flat)
    finally:
      # Undo the temporary caching device even when `fn` or loop construction
      # raises, so the caller's variable scope is never left modified.
      # TODO(akshayka): Remove the in_graph_mode check once caching devices
      # are supported in Eager
      if in_graph_mode and varscope_caching_device_was_none:
        varscope.set_caching_device(None)


@tf_export("scan")
def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
         swap_memory=False, infer_shape=True, name=None):
  """scan on the list of tensors unpacked from `elems` on dimension 0.

  The simplest version of `scan` repeatedly applies the callable `fn` to a
  sequence of elements from first to last. The elements are made of the tensors
  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of fn, and the second argument is the current element.
  If `initializer` is None, `elems` must contain at least one element, and its
  first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`.

  This method also allows multi-arity `elems` and accumulator.  If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension.  The second argument of
  `fn` must match the structure of `elems`.

  If no `initializer` is provided, the output structure and dtypes of `fn`
  are assumed to be the same as its input; and in this case, the first
  argument of `fn` must match the structure of `elems`.

  If an `initializer` is provided, then the output of `fn` must have the same
  structure as `initializer`; and the first argument of `fn` must match
  this structure.

  For example, if `elems` is `(t1, [t2, t3])` and `initializer` is
  `[i1, i2]` then an appropriate signature for `fn` in `python2` is:
  `fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]):` and `fn` must return a list,
  `[acc_n1, acc_n2]`.  An alternative correct signature for `fn`, and the
  one that works in `python3`, is:
  `fn = lambda a, t:`, where `a` and `t` correspond to the input tuples.

  Args:
    fn: The callable to be performed.  It accepts two arguments.  The first
      will have the same structure as `initializer` if one is provided,
      otherwise it will have the same structure as `elems`.  The second
      will have the same (possibly nested) structure as `elems`.  Its output
      must have the same structure as `initializer` if one is provided,
      otherwise it must have the same structure as `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which
      will be unpacked along their first dimension.  The nested sequence
      of the resulting slices will be the second argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      initial value for the accumulator, and the expected output type of `fn`.
    parallel_iterations: (optional) The number of iterations allowed to run
      in parallel.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    infer_shape: (optional) False disables tests for consistent output shapes.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors.  Each tensor packs the
    results of applying `fn` to tensors unpacked from `elems` along the first
    dimension, and the previous accumulator value(s), from first to last.

  Raises:
    TypeError: if `fn` is not callable or the structure of the output of
      `fn` and `initializer` do not match.
    ValueError: if the lengths of the output of `fn` and `initializer`
      do not match.

  Examples:
    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    sum = scan(lambda a, x: a + x, elems)
    # sum == [1, 3, 6, 10, 15, 21]
    ```

    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    initializer = np.array(0)
    sum_one = scan(
        lambda a, x: x[0] - x[1] + a, (elems + 1, elems), initializer)
    # sum_one == [1, 2, 3, 4, 5, 6]
    ```

    ```python
    elems = np.array([1, 0, 0, 0, 0, 0])
    initializer = (np.array(0), np.array(1))
    fibonaccis = scan(lambda a, _: (a[1], a[0] + a[1]), elems, initializer)
    # fibonaccis == ([1, 1, 2, 3, 5, 8], [1, 2, 3, 5, 8, 13])
    ```
  """
  if not callable(fn):
    raise TypeError("fn must be callable.")

  # Helpers that convert between the (possibly nested) user-facing structure
  # and the flat lists used internally.
  input_is_sequence = nest.is_sequence(elems)
  input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x]
  def input_pack(x):
    return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0]

  if initializer is None:
    # Without an initializer the accumulator mirrors the input structure.
    output_is_sequence = input_is_sequence
    output_flatten = input_flatten
    output_pack = input_pack
  else:
    output_is_sequence = nest.is_sequence(initializer)
    output_flatten = lambda x: nest.flatten(x) if output_is_sequence else [x]
    def output_pack(x):
      return (nest.pack_sequence_as(initializer, x)
              if output_is_sequence else x[0])

  elems_flat = input_flatten(elems)

  in_graph_mode = not context.executing_eagerly()
  with ops.name_scope(name, "scan", elems_flat):
    varscope = None
    varscope_caching_device_was_none = False
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally
      # and not issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    try:
      elems_flat = [
          ops.convert_to_tensor(elem, name="elem") for elem in elems_flat]

      # Convert elems to tensor array. n may be known statically.
      n = elems_flat[0].shape[0].value or array_ops.shape(elems_flat[0])[0]

      # TensorArrays are always flat
      elems_ta = [
          tensor_array_ops.TensorArray(dtype=elem.dtype, size=n,
                                       dynamic_size=False,
                                       infer_shape=True)
          for elem in elems_flat]
      # Unpack elements
      elems_ta = [
          elem_ta.unstack(elem) for elem_ta, elem in zip(elems_ta, elems_flat)]

      if initializer is None:
        # Seed the accumulator with the first element and start at index 1.
        a_flat = [elem.read(0) for elem in elems_ta]
        i = constant_op.constant(1)
      else:
        initializer_flat = output_flatten(initializer)
        a_flat = [ops.convert_to_tensor(init) for init in initializer_flat]
        i = constant_op.constant(0)

      # Create a tensor array to store the intermediate values.
      accs_ta = [
          tensor_array_ops.TensorArray(
              dtype=init.dtype, size=n,
              element_shape=init.shape if infer_shape else None,
              dynamic_size=False,
              infer_shape=infer_shape)
          for init in a_flat]

      if initializer is None:
        # The first element is itself the first output when it seeds the
        # accumulator.
        accs_ta = [acc_ta.write(0, a) for (acc_ta, a) in zip(accs_ta, a_flat)]

      def compute(i, a_flat, tas):
        """The loop body of scan.

        Args:
          i: the loop counter.
          a_flat: the accumulator value(s), flattened.
          tas: the output accumulator TensorArray(s), flattened.

        Returns:
          [i + 1, a_flat, tas]: the updated counter + new accumulator values +
            updated TensorArrays

        Raises:
          TypeError: if initializer and fn() output structure do not match
          ValueError: if initializer and fn() output lengths do not match
        """
        packed_elems = input_pack([elem_ta.read(i) for elem_ta in elems_ta])
        packed_a = output_pack(a_flat)
        a_out = fn(packed_a, packed_elems)
        nest.assert_same_structure(
            elems if initializer is None else initializer, a_out)
        flat_a_out = output_flatten(a_out)
        tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_a_out)]
        return (i + 1, flat_a_out, tas)

      _, _, r_a = control_flow_ops.while_loop(
          lambda i, _1, _2: i < n, compute, (i, a_flat, accs_ta),
          parallel_iterations=parallel_iterations,
          back_prop=back_prop, swap_memory=swap_memory,
          maximum_iterations=n)

      results_flat = [r.stack() for r in r_a]

      # Combine the static leading dimension of every input.  merge_with
      # returns the merged Dimension rather than mutating in place, so the
      # result must be kept for the most specific value to reach the outputs.
      n_static = elems_flat[0].get_shape().with_rank_at_least(1)[0]
      for elem in elems_flat[1:]:
        n_static = n_static.merge_with(
            elem.get_shape().with_rank_at_least(1)[0])
      for r in results_flat:
        r.set_shape(tensor_shape.TensorShape(n_static).concatenate(
            r.get_shape()[1:]))

      return output_pack(results_flat)
    finally:
      # Undo the temporary caching device even when `fn` or loop construction
      # raises, so the caller's variable scope is never left modified.
      # TODO(akshayka): Remove the in_graph_mode check once caching devices
      # are supported in Eager
      if in_graph_mode and varscope_caching_device_was_none:
        varscope.set_caching_device(None)


# pylint: disable=invalid-name
def If(cond, inputs, then_branch, else_branch, name=None):
  r"""output = Cond(inputs) ? then_branch(inputs) : else_branch(inputs).

  The output types passed to the underlying op are read from the signature of
  `then_branch`.

  Args:
    cond: A `Tensor`. A scalar. If the scalar is not a boolean, the scalar is
      converted to a boolean according to the following rule: if the
      scalar is a numerical value, non-zero means True and zero means
      False; if the scalar is a string, non-empty means True and empty
      means False.
    inputs: A list of input tensors.
    then_branch: A function that takes 'inputs' and returns a list of tensors,
        whose types are the same as what else_branch returns.
    else_branch: A function that takes 'inputs' and returns a list of tensors,
        whose types are the same as what then_branch returns.
    name: A name for the operation (optional).

  Returns:
    A list of tensors returned by either then_branch(inputs)
    or else_branch(inputs).
  """
  then_outputs = then_branch.definition.signature.output_arg
  output_types = [out_arg.type for out_arg in then_outputs]
  # pylint: disable=protected-access
  return gen_functional_ops._if(
      cond, inputs, output_types, then_branch, else_branch, name=name)


def Gradient(inputs, f, name=None):
  r"""Computes the gradient function for function f via backpropagation.

  Args:
    inputs: A list of tensors of size N + M.
    f: The function we want to compute the gradient for.

      The function 'f' must be a numerical function which takes N inputs and
      produces M outputs. Its gradient function 'g' is a function that
      takes N + M inputs and produces N outputs.

      I.e. if we have
         (y1, y2, ..., yM) = f(x1, x2, ..., xN),
      then, g is
         (dL/dx1, dL/dx2, ..., dL/dxN) = g(x1, x2, ..., xN,
                                           dL/dy1, dL/dy2, ..., dL/dyM),

      where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the
      loss function). dL/dxi is the partial derivative of L with respect
      to xi.

    name: A name for the operation (optional).

  Returns:
    A list of tensors of size N.
  """
  # TODO(zhifengc): Pretty-print the above spec in latex.
  # TODO(zhfiengc): Needs some math expert to say the comment above better.
  # One gradient is produced per input of `f`, so the output types are the
  # declared input types of `f`.
  f_input_args = f.definition.signature.input_arg
  gradient_types = [in_arg.type for in_arg in f_input_args]
  return symbolic_gradient(input=inputs, Tout=gradient_types, f=f, name=name)


# pylint: disable=invalid-name,protected-access
def While(input_, cond, body, name=None, hostmem=None):
  r"""output = input; While (Cond(output)) { output = Body(output) }.

  Args:
    input_: A list of `Tensor` objects.
      A list of input tensors whose types are T.
    cond: A function that takes 'input' and returns a tensor.  If the tensor
      is a scalar of non-boolean, the scalar is converted to a boolean
      according to the following rule: if the scalar is a numerical
      value, non-zero means True and zero means False; if the scalar is
      a string, non-empty means True and empty means False. If the
      tensor is not a scalar, non-emptiness means True and False
      otherwise.
    body: A function that takes a list of tensors and returns another
      list of tensors. Both lists have the same types as specified
      by T.
    name: A name for the operation (optional).
    hostmem: A list of integers. If i is in the list, input[i] is a
      host memory tensor.

  Returns:
    A list of `Tensor` objects. Has the same type as `input`.
    A list of output tensors whose types are T.
  """
  outputs = gen_functional_ops._while(input_, cond, body, name=name)
  if not hostmem:
    return outputs

  # The same index list is recorded for both the inputs and the outputs of
  # the op that produced the results.
  while_op = outputs[0].op
  for attr_name in ("_input_hostmem", "_output_hostmem"):
    hostmem_attr = attr_value_pb2.AttrValue()
    hostmem_attr.list.i.extend(hostmem)
    while_op._set_attr(attr_name, hostmem_attr)  # pylint: disable=protected-access
  return outputs


# b/36459430
#
# Ideally, we would not need to rewrite a For loop into a While loop.
# However, today, if a While runs on GPU and the condition returns a
# boolean, the While kernel crashes. Even if we fix the crash, the
# bool needs to be copied between GPU and CPU. So, a for loop is much
# preferred when running on GPU.
#
# On the other hand, the For op has no direct XLA kernel. So, when we run
# a for loop, we need to rewrite it using a While op.
#
# It should be possible and probably better to write a XLA C++ kernel
# implementing the logic in _ForUsingWhile.
def _ForUsingWhile(start,
                   limit,
                   delta,
                   inputs,
                   forbody,
                   name=None,
                   hostmem=None):
  """Emulates a For loop over range(start, limit, delta) with a While op.

  The While loop counts `i` from 0 up to (excluding) a trip count `n` and
  passes `start + i * delta` to `forbody` as the iteration index. Counting
  this way also covers negative deltas: for range(100, 0, -3) the trip count
  is 34 and the index handed to the body is 100 + i * (-3).

  Args:
    start: First value of the iteration index.
    limit: Bound of the iteration range.
    delta: Step between consecutive iteration indices.
    inputs: A list of loop-carried inputs.
    forbody: The For body function; it is called with the iteration index
      followed by the loop-carried values.
    name: A name for the While operation (optional).
    hostmem: Optional list of integer indices into `inputs`. Each index is
      shifted by 4 (the number of leading bookkeeping values) before being
      forwarded to `While`.

  Returns:
    A list holding the While outputs with the four leading bookkeeping values
    and the trailing loop-carried captured inputs removed.
  """
  step_size = math_ops.abs(delta)
  span = math_ops.abs(limit - start)
  # XLA on TPUs doesn't support integer division, so the rounded-up quotient
  # (span + step_size - 1) / step_size is evaluated in float32 and then cast
  # back to int32.
  numerator = math_ops.cast(span + step_size - 1, dtypes.float32)
  denominator = math_ops.cast(step_size, dtypes.float32)
  trip_count = math_ops.cast(numerator / denominator, dtypes.int32)

  # WhileCond and WhileBody must accept identical inputs. WhileBody picks up
  # forbody's carried loop variables ("extra_args") implicitly because it
  # calls forbody; WhileCond never calls forbody, so the dtypes of those
  # captured inputs have to be appended to its signature explicitly.
  body_sig = [dtypes.int32] * 4 + list(forbody.declared_input_types)[1:]
  cond_sig = body_sig + [t.dtype for t in forbody.captured_inputs]

  cond_name = "%s_Cond" % forbody.name
  body_name = "%s_Body" % forbody.name

  # NOTE(review): Defun may record the wrapped function's argument names and
  # docstring in the generated function definition, so both inner functions
  # keep their original signatures and docstrings -- confirm before editing.
  @function.Defun(*cond_sig, func_name=cond_name)
  def WhileCond(i, n, *args):
    # Only the counter and the trip count decide whether to keep looping.
    del args
    return i < n

  @function.Defun(*body_sig, func_name=body_name)
  def WhileBody(i, n, start, delta, *args):
    """A While wrapper for forbody that handles loop-carried captured inputs."""
    body_out = forbody(start + i * delta, *args)
    if isinstance(body_out, ops.Operation):
      # Nullary functions return an Operation. Normal functions can't do this
      # because their return values are converted to Tensors.
      carried = ()
    elif isinstance(body_out, ops.Tensor):
      # Unary functions return a single Tensor value.
      carried = (body_out,)
    else:
      carried = tuple(body_out)
    # Must be read after forbody has been called above.
    captured = tuple(function.get_extra_args())
    return (i + 1, n, start, delta) + carried + captured

  # Entries of `inputs` sit after the 4 bookkeeping values in the While input.
  shifted_hostmem = hostmem
  if hostmem is not None:
    shifted_hostmem = [4 + idx for idx in hostmem]

  loop_inputs = (
      [0, trip_count, start, delta] + inputs + WhileBody.captured_inputs)
  results = While(
      input_=loop_inputs,
      cond=WhileCond,
      body=WhileBody,
      name=name,
      hostmem=shifted_hostmem)
  # Slice off the bookkeeping values and the loop-carried captured inputs.
  num_captured = len(WhileBody.captured_inputs)
  return list(results[4:len(results) - num_captured])


def For(start,
        limit,
        delta,
        inputs,
        body,
        name=None,
        hostmem=None,
        rewrite_with_while=None):
  r"""out = input; for i in range(start, limit, delta) out = body(i, out).

  If `body` has captured inputs, they are appended to the loop inputs, carried
  through the loop by a wrapper function, and sliced off the result again.
  The caller's `inputs` list is left unmodified.

  Args:
    start: A `Tensor` of type `int32`.
    limit: A `Tensor` of type `int32`.
    delta: A `Tensor` of type `int32`.
    inputs: A list of `Tensor` objects.
      A list of input tensors whose types are T.
    body: A function that takes a list of tensors and returns another
      list of tensors. Both lists have the same types as (int32, T...).
    name: A name for the operation (optional).
    hostmem: A list of integers. If i is in the list, inputs[i] is a
      host memory tensor. In other words, the (i+1)-th argument of the body
      function is expected to be in host memory.
    rewrite_with_while: If True, a While op is used to implement the For
      (see `_ForUsingWhile`).

  Returns:
    A list of `Tensor` objects. Has the same type as `inputs`.
    A list of output tensors whose types are T.
  """
  if rewrite_with_while:
    return _ForUsingWhile(start, limit, delta, inputs, body, name, hostmem)
  if body.captured_inputs:
    wrapper_name = "%s_BodyWrapper" % body.name

    @function.Defun(*body.declared_input_types, func_name=wrapper_name)
    def BodyWrapper(*args):
      """A wrapper for body that handles loop-carried captured inputs."""
      body_result = body(*args)
      extra_args = tuple(function.get_extra_args())
      # Nullary functions return an Operation. Normal functions can't do this
      # because their return values are converted to Tensors.
      if isinstance(body_result, ops.Operation):
        return extra_args
      # Unary functions return a single Tensor value.
      elif not isinstance(body_result, tuple):
        return (body_result,) + extra_args
      # N-ary functions return a tuple of Tensors.
      else:
        return body_result + extra_args

    # Build a fresh list instead of `inputs += ...`: augmented assignment on a
    # list extends it in place, which would leak the captured tensors into the
    # caller's `inputs` (and accumulate them if the caller reuses the list).
    captured_inputs = list(BodyWrapper.captured_inputs)
    loop_inputs = list(inputs) + captured_inputs
    ret = gen_functional_ops._for(
        start, limit, delta, loop_inputs, BodyWrapper, name=name)
    # Slice off the loop-carried captured inputs. An explicit end index is
    # used because `ret[:-0]` would drop every element.
    ret = ret[:len(ret) - len(captured_inputs)]
  else:
    ret = gen_functional_ops._for(start, limit, delta, inputs, body, name=name)
  if hostmem:
    num_for_params = 3  # start/limit/delta

    # Op input indices are offset by the start/limit/delta params; output
    # indices line up with `inputs` directly.
    input_attr = attr_value_pb2.AttrValue()
    input_attr.list.i.extend([num_for_params + i for i in hostmem])
    ret[0].op._set_attr("_input_hostmem", input_attr)  # pylint: disable=protected-access

    output_attr = attr_value_pb2.AttrValue()
    output_attr.list.i.extend(hostmem)
    ret[0].op._set_attr("_output_hostmem", output_attr)  # pylint: disable=protected-access
  return ret
# pylint: enable=invalid-name,protected-access


def partitioned_call(args, f):
  """Invokes `f` on `args` through the generated `partitioned_call` op.

  Args:
    args: The arguments forwarded to the op as `args`.
    f: A function object whose `definition.signature.output_arg` entries
      supply the op's output types (`Tout`).

  Returns:
    The result of `gen_functional_ops.partitioned_call`.
  """
  output_types = [arg.type for arg in f.definition.signature.output_arg]
  return gen_functional_ops.partitioned_call(args=args, Tout=output_types, f=f)
